add MADDPG algorithm #444
Conversation
You can simply add those words after ReinforcementLearning.jl/.cspell/cspell.json Line 123 in 4973762

cc @pilgrimygy

Thanks! And I'd like to ask for suggestions about the implementation and any code mistakes or style issues.

I'll review it later tonight 😃
findmyway left a comment
Looks fine to me
```julia
function (π::MADDPGManager)(::PreEpisodeStage, ::AbstractEnv)
    for (_, agent) in π.agents
        if length(agent.trajectory) > 0
            pop!(agent.trajectory[:state])
            pop!(agent.trajectory[:action])
            if haskey(agent.trajectory, :legal_actions_mask)
                pop!(agent.trajectory[:legal_actions_mask])
            end
        end
    end
end
```
```julia
function (π::MADDPGManager)(::PreActStage, env::AbstractEnv, actions)
    # update each agent's trajectory
    for (player, agent) in π.agents
        push!(agent.trajectory[:state], state(env, player))
        push!(agent.trajectory[:action], actions[player])
        if haskey(agent.trajectory, :legal_actions_mask)
            lasm = legal_action_space_mask(env, player)
            push!(agent.trajectory[:legal_actions_mask], lasm)
        end
    end

    # update policy
    update!(π)
end
```
```julia
function (π::MADDPGManager)(::PostActStage, env::AbstractEnv)
    for (player, agent) in π.agents
        push!(agent.trajectory[:reward], reward(env, player))
        push!(agent.trajectory[:terminal], is_terminated(env))
    end
end
```
```julia
function (π::MADDPGManager)(::PostEpisodeStage, env::AbstractEnv)
    # collect state and dummy action to each agent's trajectory
    for (player, agent) in π.agents
        push!(agent.trajectory[:state], state(env, player))
        push!(agent.trajectory[:action], rand(action_space(env)))
        if haskey(agent.trajectory, :legal_actions_mask)
            lasm = legal_action_space_mask(env, player)
            push!(agent.trajectory[:legal_actions_mask], lasm)
        end
    end

    # update policy
    update!(π)
end
```
How about dispatching to the inner agent's corresponding methods?
Like calling agent(stage, env, action) in the for loop.
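A minimal sketch of that suggestion (assuming the inner Agent's per-stage methods handle the trajectory bookkeeping, and that per-player state resolution is handled elsewhere, e.g. via NamedPolicy):

```julia
# Hypothetical sketch, not the final implementation: forward the stage to the
# wrapped Agent so its existing trajectory-update logic is reused instead of
# pushing into each trajectory by hand.
function (π::MADDPGManager)(stage::PreActStage, env::AbstractEnv, actions)
    for (player, agent) in π.agents
        agent(stage, env, actions[player])  # the inner Agent records state/action itself
    end
    update!(π)  # then update the MADDPG policies as before
end
```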
Can you take a look at the NamedPolicy and see whether we can reuse existing code as much as possible? See also the MultiAgentManager
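For reference, a rough sketch of the pattern MultiAgentManager already uses (constructor usage here is an assumption, and `env` stands in for some multi-player AbstractEnv): each per-player policy is wrapped in a NamedPolicy, which is what makes player-specific calls work with otherwise generic code.

```julia
# Sketch only: MultiAgentManager keeps one NamedPolicy per player; reusing this
# wrapping could avoid re-implementing per-player dispatch inside MADDPGManager.
manager = MultiAgentManager(
    (NamedPolicy(p, RandomPolicy()) for p in players(env))...
)
```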
```julia
temp_player = rand(keys(π.agents))
t = π.agents[temp_player].trajectory
```
Simply use the first agent?
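For example (a sketch of the suggestion; any agent works here because all trajectories have the same length):

```julia
# Take the first agent's trajectory instead of sampling a random player.
_, first_agent = first(π.agents)
t = first_agent.trajectory
```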
```julia
temp_player = rand(keys(π.agents))
t = π.agents[temp_player].trajectory
inds = rand(π.rng, 1:length(t), π.batch_size)
batches = Dict((player, RLCore.fetch!(BatchSampler{SARTS}(π.batch_size), agent.trajectory, inds))
```
The hardcoded SARTS will make the algorithm work only on environments of MINIMAL_ACTION_SET.
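One possible way around it (a sketch only; `SLARTSL` is assumed here to be the trace tuple RLCore defines for trajectories that carry a legal-action mask):

```julia
# Choose the sampled traces based on what the trajectory actually stores, so
# environments with a FULL_ACTION_SET (which need :legal_actions_mask) still work.
traces = haskey(t, :legal_actions_mask) ? SLARTSL : SARTS  # SLARTSL is an assumption
sampler = BatchSampler{traces}(π.batch_size)
batches = Dict(
    player => RLCore.fetch!(sampler, agent.trajectory, inds)
    for (player, agent) in π.agents
)
```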
```julia
s = vcat((batches[player][1] for (player, _) in π.agents)...)
a = vcat((batches[player][2] for (player, _) in π.agents)...)
s′ = vcat((batches[player][5] for (player, _) in π.agents)...)
```
vcat is not very efficient here. Try Flux.batch?
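For illustration (a standalone example with made-up shapes, not the PR's actual data):

```julia
using Flux

# Two hypothetical per-player state batches, each of shape (state_dim, batch_size).
xs = [rand(Float32, 4, 32), rand(Float32, 4, 32)]

vcat(xs...)     # concatenates along dim 1 -> 8×32 Matrix; the splat adds overhead
Flux.batch(xs)  # stacks along a new trailing dim -> 4×32×2 Array, no splatting needed
```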
```julia
s, a, s′ = send_to_host((s, a, s′))
mu_actions = send_to_host(mu_actions)
new_actions = send_to_host(new_actions)
```
Are they required here?
Thanks for your kind reviews! I'll check and update my code later today.
Here is still a simple version of
PR Checklist
The description of the implementation is in discussion #404.
Here MADDPG raises an unknown word error... How can I fix it? @findmyway